Patrick Altmeyer
CounterfactualExplanations.jl: getting started

“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”
— Cathy O’Neil in Weapons of Math Destruction, 2016
We have fitted a black-box classifier to separate cats and dogs. One 🐱 is friends with a lot of cool 🐶 and wants to remain part of that group. The counterfactual path below shows her how to fool the classifier:
CounterfactualExplanations.jl 📦

Julia is fast, transparent, beautiful and open 🔴🟢🟣
\[ \min_{x^\prime \in \mathcal{X}} h(x^\prime) \ \ \ \mbox{s.t.} \ \ \ M(x^\prime) = t \qquad(1)\]
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda h(x^\prime) \qquad(2)\]
So counterfactual search is just gradient descent in the feature space 💡 Easy, right?
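To make that idea concrete, here is a minimal hand-rolled sketch of Eq. (2) — hypothetical illustration code, not the package API. It assumes a logistic model $M(x) = \sigma(w \cdot x + b)$, binary cross-entropy for $\ell$, and a squared-distance penalty for $h$:

```julia
# Minimal sketch of Eq. (2): gradient descent on ℓ(M(x′), t) + λ‖x′ − x‖².
σ(z) = 1 / (1 + exp(-z))

function counterfactual_search(x, t, w, b; λ=0.1, η=0.1, iters=1000)
    x′ = copy(x)
    for _ in 1:iters
        p = σ(sum(w .* x′) + b)        # current predicted probability
        ∇ℓ = (p - t) .* w              # gradient of binary cross-entropy w.r.t. x′
        ∇h = 2 .* (x′ .- x)            # gradient of the squared-distance penalty
        x′ .-= η .* (∇ℓ .+ λ .* ∇h)    # gradient step
    end
    return x′
end

x = [-1.0, -1.0]                       # factual on the negative side of the boundary
x′ = counterfactual_search(x, 1.0, [1.0, 1.0], 0.0)
```

With the same true coefficients as the getting-started example below, the search crosses the decision boundary while the penalty keeps the counterfactual close to the factual.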
Effective counterfactuals should meet certain criteria ✅
\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) \ \ , \ \ \forall M\in\widetilde{\mathcal{M}} \qquad(3)\]
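Equation (3) asks the counterfactual to be valid not just for one point estimate $M$ but for a whole set of plausible models $\widetilde{\mathcal{M}}$. One simple way to approximate this — a hypothetical sketch, not the package API — is to average the loss over an ensemble of fitted models:

```julia
# Sketch of Eq. (3): average the counterfactual loss over an ensemble of
# logistic models instead of a single point estimate.
σ(z) = 1 / (1 + exp(-z))
bce(p, t) = -t * log(p) - (1 - t) * log(1 - p)   # binary cross-entropy

# Each model is a (w, b) pair; the ensemble loss is the mean individual loss:
function ensemble_loss(x′, t, ensemble)
    sum(bce(σ(sum(w .* x′) + b), t) for (w, b) in ensemble) / length(ensemble)
end

# Two slightly different fits of the same linear model:
ensemble = [([1.0, 1.0], 0.0), ([0.9, 1.1], 0.1)]
ensemble_loss([2.0, 2.0], 1.0, ensemble)   # low: both models agree on the target class
```

A counterfactual that minimizes this averaged loss is more likely to remain valid under model uncertainty — the same motivation behind the Bayesian model and greedy generator used below.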
# Data:
using CounterfactualExplanations
using CounterfactualExplanations.Data
using Random
Random.seed!(1234)
N = 25
xs, ys = Data.toy_data_linear(N)
X = hcat(xs...)
counterfactual_data = CounterfactualData(X,ys')
# Model
using CounterfactualExplanations.Models: LogisticModel, probs
# Logit model:
w = [1.0 1.0] # true coefficients
b = 0
M = LogisticModel(w, [b])
# Randomly selected factual:
Random.seed!(123)
x = select_factual(counterfactual_data, rand(1:size(X, 2)))
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target
# Counterfactual search:
generator = GenericGenerator()
counterfactual = generate_counterfactual(x, target, counterfactual_data, M, generator)

# Model:
using LinearAlgebra
Σ = Symmetric(reshape(randn(9),3,3).*0.01 + UniformScaling(1)) # MAP covariance matrix
μ = hcat(b, w)
M = CounterfactualExplanations.Models.BayesianLogisticModel(μ, Σ)
# Counterfactual search:
generator = GreedyGenerator(; δ=0.1, n=25)
counterfactual = generate_counterfactual(x, target, counterfactual_data, M, generator)

using Flux, RCall
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend
# Step 1)
struct TorchNetwork <: Models.AbstractFittedModel
nn::Any
end
# Step 2)
function logits(M::TorchNetwork, X::AbstractArray)
nn = M.nn
ŷ = rcopy(R"as_array($nn(torch_tensor(t($X))))")
ŷ = isa(ŷ, AbstractArray) ? ŷ : [ŷ]
return ŷ'
end
probs(M::TorchNetwork, X::AbstractArray) = σ.(logits(M, X))
M = TorchNetwork(R"model")

import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra
# Counterfactual loss:
function ∂ℓ(generator::AbstractGradientBasedGenerator, counterfactual_state::CounterfactualState)
M = counterfactual_state.M
nn = M.nn
x′ = counterfactual_state.x′
t = counterfactual_state.target_encoded
R"""
x <- torch_tensor($x′, requires_grad=TRUE)
output <- $nn(x)
obj_loss <- nnf_binary_cross_entropy_with_logits(output,$t)
obj_loss$backward()
"""
grad = rcopy(R"as_array(x$grad)")
return grad
end

using Flux, PyCall
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend
# Step 1)
struct PyTorchNetwork <: Models.AbstractFittedModel
nn::Any
end
# Step 2)
function logits(M::PyTorchNetwork, X::AbstractArray)
nn = M.nn
if !isa(X, Matrix)
X = reshape(X, length(X), 1)
end
ŷ = py"$nn(torch.Tensor($X).T).detach().numpy()"
ŷ = isa(ŷ, AbstractArray) ? ŷ : [ŷ]
return ŷ
end
probs(M::PyTorchNetwork, X::AbstractArray) = σ.(logits(M, X))
M = PyTorchNetwork(py"model")

import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra
# Counterfactual loss:
function ∂ℓ(generator::AbstractGradientBasedGenerator, counterfactual_state::CounterfactualState)
M = counterfactual_state.M
nn = M.nn
x′ = counterfactual_state.x′
t = counterfactual_state.target_encoded
x = reshape(x′, 1, length(x′))
py"""
x = torch.Tensor($x)
x.requires_grad = True
t = torch.Tensor($[t]).squeeze()
output = $nn(x).squeeze()
obj_loss = nn.BCEWithLogitsLoss()(output,t)
obj_loss.backward()
"""
grad = vec(py"x.grad.detach().numpy()")
return grad
end

# Abstract supertype:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end
# Constructor:
struct DropoutGenerator <: AbstractDropoutGenerator
loss::Symbol # loss function
complexity::Function # complexity function
mutability::Union{Nothing,Vector{Symbol}} # mutability constraints
λ::AbstractFloat # strength of penalty
ϵ::AbstractFloat # step size
τ::AbstractFloat # tolerance for convergence
p_dropout::AbstractFloat # dropout rate
end
# Instantiate:
using LinearAlgebra
generator = DropoutGenerator(
:logitbinarycrossentropy,
norm,
nothing,
0.1,
0.1,
1e-5,
0.5
)

import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::AbstractDropoutGenerator, counterfactual_state::CounterfactualState)
𝐠ₜ = ∇(generator, counterfactual_state) # gradient
# Dropout:
set_to_zero = sample(1:length(𝐠ₜ),Int(round(generator.p_dropout*length(𝐠ₜ))),replace=false)
𝐠ₜ[set_to_zero] .= 0
Δx′ = - (generator.ϵ .* 𝐠ₜ) # gradient step
return Δx′
end

This looks nice 🤓
And this … ugh 🥴
(Flux, torch, tensorflow) and other differentiable models.

What happens once AR has actually been implemented? 👀
Explaining black-box models through counterfactuals